package com.bosssoft.platform.fasttcc.impl;

import com.bosssoft.platform.fasttcc.*;
import com.jfireframework.baseutil.TRACEID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Default {@link TccTransaction} implementation, acting either as the coordinator (created without a propagator)
 * or as a participator (created with the identifier it was propagated by, see {@link #getPropagatedBy()}).
 * <p>
 * The inherited {@link AtomicInteger} value counts the complete-stage operations (unfinished TCC invokes plus
 * unfinished remote resources) still outstanding in the current confirm/cancel round; {@code -1} means no round
 * is running. {@link #processCompleteStage()} claims a round by CAS-ing {@code -1} to the operation count, every
 * operation reports back through {@link #onFinish()} or {@link #onRemoteFinish(RemoteResource, boolean)}, and the
 * report that brings the counter to zero closes the round: the transaction is marked finished when every
 * operation completed, otherwise it is handed to the manager for a later retry.
 */
public class TccTransactionImpl extends AtomicInteger implements TccTransaction, CompleteStageListener
{
    private final        Xid                         xid;
    private final        byte                        role;
    private final        String                      propagatedBy;
    private final        TccLogger                   tccLogger;
    private final        TccTransactionManager       manager;
    /**
     * Stored by {@link #setAsync(ExecutorService)}; not read anywhere in this class.
     */
    private              ExecutorService             executorService;
    /**
     * Passed through to every confirm/cancel call of the complete stage.
     */
    private              boolean                     async         = false;
    /**
     * The innermost local transaction currently in progress, or {@code null} when none is.
     */
    private              LocalTransaction            currentLocalTx;
    /**
     * Always needed, so the list is created eagerly.
     */
    private              List<LocalTransaction>      localTxList   = new ArrayList<>(4);
    private              List<TccInvoke>             tccInvokeList = new ArrayList<>(4);
    /**
     * Created lazily on the first registered remote resource; may be {@code null}.
     */
    private              List<RemoteResourceContext> remoteResourceContexts;
    /**
     * Volatile because it is also read from complete-stage callbacks ({@link #onFinish()}).
     */
    private volatile     int                         state;
    private static final Logger                      LOGGER        = LoggerFactory.getLogger(TccTransactionImpl.class);

    /**
     * Creates a coordinator-side transaction in the {@code ACTIVE} state.
     *
     * @param xid       global id of this TCC transaction
     * @param tccLogger logger used to persist registrations and state changes
     * @param manager   manager that receives transactions whose complete stage must be retried
     */
    public TccTransactionImpl(Xid xid, TccLogger tccLogger, TccTransactionManager manager)
    {
        super(-1);
        role = COORDINATOR;
        propagatedBy = null;
        this.xid = xid;
        this.tccLogger = tccLogger;
        this.manager = manager;
        state = ACTIVE;
    }

    /**
     * Creates a participator-side transaction in the {@code ACTIVE} state.
     *
     * @param xid          global id of this TCC transaction
     * @param tccLogger    logger used to persist registrations and state changes
     * @param manager      manager that receives transactions whose complete stage must be retried
     * @param propagatedBy identifier of the party this transaction was propagated from
     */
    public TccTransactionImpl(Xid xid, TccLogger tccLogger, TccTransactionManager manager, String propagatedBy)
    {
        super(-1);
        role = PARTICIPATOR;
        this.propagatedBy = propagatedBy;
        this.xid = xid;
        this.tccLogger = tccLogger;
        this.manager = manager;
        state = ACTIVE;
    }

    /**
     * Switches the complete stage to asynchronous mode: {@code async = true} is passed to every confirm/cancel
     * call. The executor is only stored here.
     *
     * @param executorService executor to remember for asynchronous execution
     */
    public void setAsync(ExecutorService executorService)
    {
        this.async = true;
        this.executorService = executorService;
    }

    /**
     * @return the xid of the local transaction currently in progress, or {@code null} when there is none
     */
    @Override
    public Xid getCurrentLocalTransactionXid()
    {
        return currentLocalTx == null ? null : currentLocalTx.xid();
    }

    /**
     * @return whether the first registered local transaction has been committed
     * @throws IndexOutOfBoundsException if no local transaction has been registered yet
     */
    @Override
    public boolean isFirstLocalTransactionCommited()
    {
        return localTxList.get(0).isCommited();
    }

    /**
     * @return {@code COORDINATOR} or {@code PARTICIPATOR}, depending on the constructor used
     */
    @Override
    public byte role()
    {
        return role;
    }

    /**
     * Pushes a new local transaction: it remembers the current one as its predecessor, gets the next index,
     * becomes the current local transaction and is persisted through the {@link TccLogger}.
     *
     * @param localTransaction the local transaction being started
     */
    public void registerLocalTransaction(LocalTransaction localTransaction)
    {
        localTransaction.setPreLocalTransaction(currentLocalTx);
        localTransaction.setIndex(localTxList.size());
        localTxList.add(localTransaction);
        currentLocalTx = localTransaction;
        tccLogger.registerLocalTransaction(currentLocalTx);
    }

    /**
     * Records a TCC invoke, associating it with the current local transaction (which may be {@code null}),
     * and persists it through the {@link TccLogger}.
     *
     * @param tccInvoke the invoke whose try phase is being executed
     */
    public void registerTccInvoke(TccInvoke tccInvoke)
    {
        tccInvoke.associateLocalTransaction(currentLocalTx);
        tccInvokeList.add(tccInvoke);
        tccLogger.registerTccInvoke(tccInvoke, this);
    }

    /**
     * Enlists a remote resource unless one that {@link RemoteResource#isSameInstance is the same instance} is
     * already enlisted; newly enlisted resources are persisted through the {@link TccLogger}.
     *
     * @param remoteResource the remote resource taking part in this transaction
     */
    public void registerRemoteResource(RemoteResource remoteResource)
    {
        String traceId = TRACEID.currentTraceId();
        if (remoteResourceContexts == null)
        {
            remoteResourceContexts = new ArrayList<>(4);
        }
        for (RemoteResourceContext context : remoteResourceContexts)
        {
            if (context.remoteResource.isSameInstance(remoteResource))
            {
                LOGGER.debug("traceId:{} 远端资源:{}与已经入列的远端资源:{}是相同的远端实例，本次不入列", traceId, remoteResource.getIdentifier(), context.remoteResource.getIdentifier());
                return;
            }
        }
        LOGGER.debug("traceId:{} 远端资源:{}入列事务:{}", traceId, remoteResource.getIdentifier(), xid);
        RemoteResourceContext context = new RemoteResourceContext();
        context.remoteResource = remoteResource;
        context.completed = false;
        remoteResourceContexts.add(context);
        tccLogger.registerRemoteResource(remoteResource, xid);
    }

    /**
     * Marks the current local transaction committed and makes its predecessor current again.
     *
     * @throws NullPointerException if no local transaction is in progress
     */
    @Override
    public void commitCurrentLocalTransaction()
    {
        LOGGER.debug("traceId:{} 当前本地事务xid:{}提交成功", TRACEID.currentTraceId(), currentLocalTx.xid());
        currentLocalTx.setCommited();
        resumePreLocalTransaction();
    }

    /**
     * Pops the current local transaction: its predecessor (by index) becomes current, or {@code null} when the
     * predecessor index is {@code -1}.
     */
    private void resumePreLocalTransaction()
    {
        int preLocalTxIndex = currentLocalTx.preLocalTxIndex();
        currentLocalTx = preLocalTxIndex == -1 ? null : localTxList.get(preLocalTxIndex);
    }

    /**
     * Marks the current local transaction rolled back and makes its predecessor current again.
     *
     * @throws NullPointerException if no local transaction is in progress
     */
    @Override
    public void rollbackCurrentLocalTransaction()
    {
        LOGGER.debug("traceId:{} 当前本地事务xid:{}回滚完成", TRACEID.currentTraceId(), currentLocalTx.xid());
        currentLocalTx.setRollbacked();
        resumePreLocalTransaction();
    }

    /**
     * Runs the complete stage (confirm or cancel) for every operation that has not completed yet.
     * <p>
     * A {@code FINISHED} transaction is left untouched. Otherwise a round starts only if none is in progress
     * (counter is {@code -1}); concurrent callers simply return. With nothing outstanding the transaction is
     * marked finished immediately; otherwise confirm ({@code MARK_FOR_COMMIT}) or cancel (any other state) is
     * dispatched to each outstanding operation.
     *
     * @throws IllegalStateException if the transaction is still {@code ACTIVE}
     */
    @Override
    public void processCompleteStage()
    {
        String traceId = TRACEID.currentTraceId();
        // Snapshot the volatile state once so all decisions below agree with each other.
        int currentState = state;
        if (currentState == ACTIVE)
        {
            throw new IllegalStateException("激活态不可以执行完成阶段，当前是" + currentState);
        }
        if (currentState == FINISHED)
        {
            LOGGER.debug("traceId:{} 事务:{}当前是结束状态，无变化", traceId, xid);
            return;
        }
        int numOfHandler = calculateNumOfHandler();
        // Claim the round: only the caller that moves the counter from -1 to the operation count proceeds.
        if (get() != -1 || !compareAndSet(-1, numOfHandler))
        {
            return;
        }
        if (numOfHandler == 0)
        {
            LOGGER.debug("traceId:{} TCC事务:{}当前没有剩余任务，标记状态为完成", traceId, xid);
            markForComplete();
            set(-1);
            return;
        }
        if (currentState == MARK_FOR_COMMIT)
        {
            LOGGER.debug("traceId:{} TCC事务:{}执行完成阶段提交分支，当前操作数:{}", traceId, xid, numOfHandler);
            doCommit();
        }
        else
        {
            LOGGER.debug("traceId:{} TCC事务:{}执行完成阶段回滚分支，当前操作数:{}", traceId, xid, numOfHandler);
            doRollback();
        }
    }

    /**
     * @return the number of TCC invokes plus remote resources that have not completed their complete stage
     */
    private int calculateNumOfHandler()
    {
        int numOfHandler = 0;
        for (TccInvoke each : tccInvokeList)
        {
            if (!each.isCompleted())
            {
                numOfHandler++;
            }
        }
        if (remoteResourceContexts != null)
        {
            for (RemoteResourceContext each : remoteResourceContexts)
            {
                if (!each.completed)
                {
                    numOfHandler++;
                }
            }
        }
        return numOfHandler;
    }

    /**
     * Dispatches cancel to every unfinished TCC invoke and remote resource. Each dispatched operation (or the
     * shortcut below) reports back through {@link #onFinish()} exactly once.
     */
    private void doRollback()
    {
        for (final TccInvoke tccInvoke : tccInvokeList)
        {
            if (tccInvoke.isCompleted())
            {
                continue;
            }
            // The associated local transaction rolled back, so nothing was actually written and the cancel
            // branch is unnecessary: mark the invoke complete directly. An invoke registered outside any local
            // transaction has no association and is cancelled normally.
            if (tccInvoke.getAssociatedLocalTransaction() != null && tccInvoke.getAssociatedLocalTransaction().isRollbacked())
            {
                tccInvoke.markCompleted();
                // Report through onFinish() rather than a bare decrement: if this was the last outstanding
                // operation the round must still be closed, otherwise the counter would stay at 0 and every
                // later processCompleteStage() call would be rejected.
                onFinish();
                continue;
            }
            tccInvoke.rollback(async, this);
        }
        if (remoteResourceContexts != null)
        {
            for (final RemoteResourceContext each : remoteResourceContexts)
            {
                if (!each.completed)
                {
                    each.remoteResource.rollback(async, xid, this);
                }
            }
        }
    }

    /**
     * Dispatches confirm to every unfinished TCC invoke and remote resource.
     */
    private void doCommit()
    {
        String traceId = TRACEID.currentTraceId();
        for (final TccInvoke tccInvoke : tccInvokeList)
        {
            if (!tccInvoke.isCompleted())
            {
                TccOperation tccOperation = tccInvoke.getTccOperation();
                LOGGER.debug("traceId:{} tcc操作:{}.{}提交执行确认分支，确认分支xid:{}", traceId, tccOperation.getTryClass().getName(), tccOperation.getTryMethod().getName(), tccInvoke.getCompleteStageXid());
                tccInvoke.commit(async, this);
            }
        }
        if (remoteResourceContexts != null)
        {
            for (final RemoteResourceContext each : remoteResourceContexts)
            {
                if (!each.completed)
                {
                    LOGGER.debug("traceId:{} 远端资源:{}提交确认请求，TCC事务Xid:{}", traceId, each.remoteResource.getIdentifier(), xid);
                    each.remoteResource.commit(async, xid, this);
                }
            }
        }
    }

    @Override
    public int getState()
    {
        return state;
    }

    /**
     * Moves {@code ACTIVE} to {@code MARK_FOR_COMMIT} and persists the new state.
     *
     * @throws IllegalStateException if the transaction is not active
     */
    @Override
    public void markForCommit()
    {
        if (state != ACTIVE)
        {
            throw new IllegalStateException("非激活态不可以提交，当前是" + state);
        }
        state = MARK_FOR_COMMIT;
        tccLogger.updateTccTransactionState(this);
    }

    /**
     * Moves {@code ACTIVE} to {@code MARK_FOR_ROLLBACK} and persists the new state.
     *
     * @throws IllegalStateException if the transaction is not active
     */
    @Override
    public void markForRollback()
    {
        if (state != ACTIVE)
        {
            throw new IllegalStateException("非激活态不可以回滚，当前是" + state);
        }
        state = MARK_FOR_ROLLBACK;
        tccLogger.updateTccTransactionState(this);
    }

    /**
     * Moves {@code MARK_FOR_COMMIT} or {@code MARK_FOR_ROLLBACK} to {@code FINISHED} and persists the new state.
     *
     * @throws IllegalStateException if the transaction is in neither marked state
     */
    public void markForComplete()
    {
        if (state != MARK_FOR_COMMIT && state != MARK_FOR_ROLLBACK)
        {
            throw new IllegalStateException("非标记提交或标记回滚状态不可以标记完成，当前是" + state);
        }
        state = FINISHED;
        tccLogger.updateTccTransactionState(this);
    }


    /**
     * @return the identifier this transaction was propagated by, or {@code null} for a coordinator
     */
    @Override
    public String getPropagatedBy()
    {
        return propagatedBy;
    }

    @Override
    public Xid getXid()
    {
        return xid;
    }

    /**
     * Replaces the local transaction list wholesale (the list is used as-is, not copied).
     * NOTE(review): presumably used when rebuilding a transaction from the log — confirm against callers.
     */
    public void resetLocalTransactions(List<LocalTransaction> localTransactions)
    {
        this.localTxList = localTransactions;
    }

    /**
     * Replaces the TCC invoke list wholesale (the list is used as-is, not copied).
     */
    public void resetTccInvokes(List<TccInvoke> tccInvokes)
    {
        this.tccInvokeList = tccInvokes;
    }

    /**
     * Replaces the enlisted remote resources with the given ones, all marked not completed. Unlike
     * {@link #registerRemoteResource(RemoteResource)}, no de-duplication is performed.
     */
    public void resetRemoteResources(List<RemoteResource> remoteResources)
    {
        remoteResourceContexts = new ArrayList<>();
        for (RemoteResource remoteResource : remoteResources)
        {
            RemoteResourceContext context = new RemoteResourceContext();
            context.remoteResource = remoteResource;
            context.completed = false;
            remoteResourceContexts.add(context);
        }
    }

    /**
     * Callback from a remote resource's confirm/cancel. On success, every enlisted context for the same remote
     * instance is marked completed; in either case the report is counted via {@link #onFinish()}.
     *
     * @param remoteResource the remote resource that finished
     * @param completed      whether its complete stage succeeded
     */
    @Override
    public void onRemoteFinish(RemoteResource remoteResource, boolean completed)
    {
        if (completed && remoteResourceContexts != null)
        {
            for (RemoteResourceContext context : remoteResourceContexts)
            {
                if (context.remoteResource.isSameInstance(remoteResource))
                {
                    context.completed = true;
                }
            }
        }
        onFinish();
    }

    /**
     * Counts one finished complete-stage operation. The report that brings the counter to zero closes the
     * round: if every TCC invoke and remote resource completed, the transaction is marked finished; otherwise
     * it is handed to the manager for a retry. In both cases the counter is reset to {@code -1} (idle).
     */
    @Override
    public void onFinish()
    {
        String traceId = TRACEID.currentTraceId();
        int    left    = decrementAndGet();
        if (left != 0)
        {
            LOGGER.debug("traceId:{} TCC事务:{}当前完成阶段尚未执行完毕，剩余操作数:{}", traceId, xid, left);
            return;
        }
        LOGGER.debug("traceId:{} TCC事务:{}完成阶段执行完毕", traceId, xid);
        boolean finish = true;
        for (TccInvoke tccInvoke : tccInvokeList)
        {
            if (!tccInvoke.isCompleted())
            {
                LOGGER.debug("traceId:{} TCC调用:{}完成阶段执行失败，TCC事务无法进入结束状态", traceId, tccInvoke.getCompleteStageXid());
                finish = false;
                break;
            }
        }
        if (remoteResourceContexts != null)
        {
            for (RemoteResourceContext context : remoteResourceContexts)
            {
                if (!context.completed)
                {
                    LOGGER.debug("traceId:{} 远端资源:{}完成阶段执行失败,TCC事务无法进入结束状态", traceId, context.remoteResource.getIdentifier());
                    finish = false;
                    break;
                }
            }
        }
        if (finish)
        {
            LOGGER.debug("traceId:{} TCC事务:{}顺利完成，标记状态为完成", traceId, xid);
            markForComplete();
            set(-1);
        }
        else
        {
            LOGGER.debug("traceId:{} TCC事务:{}完成阶段执行失败，加入重试集合", traceId, xid);
            // Reset the round guard BEFORE publishing to the retry set: a retry that runs immediately would
            // otherwise see a counter of 0 and be silently rejected by processCompleteStage().
            set(-1);
            manager.addReCompleteTransaction(this);
        }
    }

    /**
     * Overwrites the state without any transition check and without persisting it.
     * NOTE(review): presumably used when restoring a transaction from the log — confirm against callers.
     */
    public void setState(int state)
    {
        this.state = state;
    }

    /**
     * An enlisted remote resource together with whether its complete stage has succeeded. Static because it
     * never touches the enclosing transaction.
     */
    static class RemoteResourceContext
    {
        RemoteResource remoteResource;
        /**
         * Volatile: set from {@link TccTransactionImpl#onRemoteFinish} callbacks.
         */
        volatile boolean completed;
    }
}
